import os
import re
from typing import Union

import pytest
from _pytest.monkeypatch import MonkeyPatch
from requests import Response
from requests.exceptions import ConnectionError
from requests.sessions import Session
from xinference_client.client.restful.restful_client import (
    Client,
    RESTfulChatglmCppChatModelHandle,
    RESTfulChatModelHandle,
    RESTfulEmbeddingModelHandle,
    RESTfulGenerateModelHandle,
    RESTfulRerankModelHandle,
)
from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage


class MockXinferenceClass:
    """Stand-in implementations for the xinference client used by tests.

    The functions defined here are not meant to be called on an instance of
    this class. The ``setup_xinference_mock`` fixture in this module patches
    them onto ``Client``, ``Session`` and the RESTful model handles, so
    ``self`` is an instance of the patched class (see each ``self``
    annotation).
    """

    # Model uids the mock recognises in addition to UUID-shaped uids.
    _KNOWN_MODEL_UIDS = ('generate', 'chat', 'embedding', 'rerank')
    # Used with ``re.match``: anchored at the start only, so trailing
    # characters after a UUID prefix are tolerated.
    _UUID_PATTERN = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
    # An http(s) URL without whitespace.
    _URL_PATTERN = r'^(https?):\/\/[^\s\/$.?#].[^\s]*$'

    @staticmethod
    def _not_found() -> Response:
        """Build the 404 response with an empty JSON object as its body."""
        response = Response()
        response.status_code = 404
        response._content = b'{}'
        return response

    def get_chat_model(self: Client, model_uid: str) -> Union[
        RESTfulChatglmCppChatModelHandle,
        RESTfulGenerateModelHandle,
        RESTfulChatModelHandle,
        RESTfulEmbeddingModelHandle,
        RESTfulRerankModelHandle,
    ]:
        """Return the model handle matching one of the well-known mock uids.

        :param model_uid: 'generate', 'chat', 'embedding' or 'rerank'.
        :return: a handle of the corresponding RESTful class, bound to
            ``self.base_url`` with empty auth headers.
        :raises RuntimeError: ('404 Not Found') when ``self.base_url`` is not
            an http(s) URL or ``model_uid`` is not one of the uids above.
        """
        if not re.match(MockXinferenceClass._URL_PATTERN, self.base_url):
            raise RuntimeError('404 Not Found')

        handle_classes = {
            'generate': RESTfulGenerateModelHandle,
            'chat': RESTfulChatModelHandle,
            'embedding': RESTfulEmbeddingModelHandle,
            'rerank': RESTfulRerankModelHandle,
        }
        handle_cls = handle_classes.get(model_uid)
        if handle_cls is None:
            raise RuntimeError('404 Not Found')
        return handle_cls(model_uid, base_url=self.base_url, auth_headers={})

    def get(self: Session, url: str, **kwargs) -> Response:
        """Answer GET requests for model descriptions and the cluster auth probe.

        :param url: requested URL; only URLs containing 'v1/models/' or
            'v1/cluster/auth' receive a canned payload.
        :param kwargs: accepted for signature compatibility and ignored.
        :return: a ``Response`` with status 200 and a JSON body for the
            'generate', 'chat' and 'embedding' uids and for the auth probe;
            a 404 response with body ``{}`` in every other case.
        """
        if 'v1/models/' in url:
            # The model uid is the last path segment.
            model_uid = url.split('/')[-1]
            is_known_uid = (
                re.match(MockXinferenceClass._UUID_PATTERN, model_uid)
                or model_uid in MockXinferenceClass._KNOWN_MODEL_UIDS
            )
            if not is_known_uid:
                return MockXinferenceClass._not_found()

            # check if url is valid
            if not re.match(MockXinferenceClass._URL_PATTERN, url):
                return MockXinferenceClass._not_found()

            if model_uid in ('generate', 'chat'):
                response = Response()
                response.status_code = 200
                response._content = b'''{
                    "model_type": "LLM",
                    "address": "127.0.0.1:43877",
                    "accelerators": [
                        "0",
                        "1"
                    ],
                    "model_name": "chatglm3-6b",
                    "model_lang": [
                        "en"
                    ],
                    "model_ability": [
                        "generate",
                        "chat"
                    ],
                    "model_description": "latest chatglm3",
                    "model_format": "pytorch",
                    "model_size_in_billions": 7,
                    "quantization": "none",
                    "model_hub": "huggingface",
                    "revision": null,
                    "context_length": 2048,
                    "replica": 1
                }'''
                return response

            if model_uid == 'embedding':
                response = Response()
                response.status_code = 200
                response._content = b'''{
                    "model_type": "embedding",
                    "address": "127.0.0.1:43877",
                    "accelerators": [
                        "0",
                        "1"
                    ],
                    "model_name": "bge",
                    "model_lang": [
                        "en"
                    ],
                    "revision": null,
                    "max_tokens": 512
                }'''
                return response

        elif 'v1/cluster/auth' in url:
            response = Response()
            response.status_code = 200
            response._content = b'''{
                "auth": true
            }'''
            return response

        # Accepted uids without a canned payload ('rerank', UUID-shaped) and
        # unrelated URLs get an explicit 404 instead of an implicit None, so
        # callers can always read ``status_code`` on the result.
        return MockXinferenceClass._not_found()

    def _check_cluster_authenticated(self):
        """Mark the cluster as authenticated without performing any request."""
        self._cluster_authed = True

    def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: Union[int, None]) -> dict:
        """Return a fixed-score ranking of the first ``top_n`` documents.

        :param documents: documents to "rank"; their order is kept.
        :param query: accepted for signature compatibility and ignored.
        :param top_n: number of leading documents to return; ``None`` means 1.
        :return: ``{'results': [...]}`` where each entry holds the document's
            index, the document itself and a relevance score of 0.9.
        :raises RuntimeError: ('404 Not Found') when ``self._model_uid`` is
            neither UUID-shaped nor 'rerank', or ``self._base_url`` is not an
            http(s) URL.
        """
        # check if self._model_uid is a valid uuid
        uid_accepted = (
            re.match(MockXinferenceClass._UUID_PATTERN, self._model_uid)
            or self._model_uid == 'rerank'
        )
        if not uid_accepted:
            raise RuntimeError('404 Not Found')

        if not re.match(MockXinferenceClass._URL_PATTERN, self._base_url):
            raise RuntimeError('404 Not Found')

        if top_n is None:
            top_n = 1

        return {
            'results': [
                {
                    'index': i,
                    'document': doc,
                    'relevance_score': 0.9
                }
                for i, doc in enumerate(documents[:top_n])
            ]
        }

    def create_embedding(
        self: RESTfulEmbeddingModelHandle,
        input: Union[str, list[str]],
        **kwargs
    ) -> Embedding:
        """Return one constant 768-dimensional vector per input text.

        :param input: a single text or a list of texts.
        :param kwargs: accepted for signature compatibility and ignored.
        :return: an ``Embedding`` whose ``data`` has one entry per input and
            whose usage counters both equal the number of inputs.
        :raises RuntimeError: ('404 Not Found') when ``self._model_uid`` is
            neither UUID-shaped nor 'embedding'.
        """
        # check if self._model_uid is a valid uuid
        uid_accepted = (
            re.match(MockXinferenceClass._UUID_PATTERN, self._model_uid)
            or self._model_uid == 'embedding'
        )
        if not uid_accepted:
            raise RuntimeError('404 Not Found')

        if isinstance(input, str):
            input = [input]
        ipt_len = len(input)

        embedding = Embedding(
            object="list",
            model=self._model_uid,
            data=[
                EmbeddingData(
                    index=i,
                    object="embedding",
                    embedding=[1919.810] * 768
                )
                for i in range(ipt_len)
            ],
            usage=EmbeddingUsage(
                prompt_tokens=ipt_len,
                total_tokens=ipt_len
            )
        )

        return embedding

# Mocking is opt-in: it is enabled only when the MOCK_SWITCH environment
# variable is set to "true" (compared case-insensitively).
MOCK = os.environ.get('MOCK_SWITCH', 'false').lower() == 'true'

@pytest.fixture
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
    """Install the xinference mocks for one test when ``MOCK`` is enabled.

    With ``MOCK`` set, the xinference client entry points are replaced by the
    corresponding ``MockXinferenceClass`` functions before the test runs, and
    the patches are rolled back explicitly once the test has finished. With
    ``MOCK`` unset the fixture only yields and leaves everything untouched.
    """
    if MOCK:
        # (patched class, attribute name, replacement), applied in this order.
        replacements = (
            (Client, 'get_model', MockXinferenceClass.get_chat_model),
            (Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated),
            (Session, 'get', MockXinferenceClass.get),
            (RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding),
            (RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank),
        )
        for target, attribute, replacement in replacements:
            monkeypatch.setattr(target, attribute, replacement)

    yield

    if MOCK:
        monkeypatch.undo()